{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "github.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O5s4Uvtx0AnL"
      },
      "outputs": [],
      "source": [
        "# Work from the Drive folder; the data and main.py paths used below are relative to it.\n",
        "# NOTE(review): assumes Google Drive is already mounted under the current directory - confirm.\n",
        "%cd drive/MyDrive/"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import os\n",
        "from sklearn.model_selection import train_test_split"
      ],
      "metadata": {
        "id": "LqyYaqfZnIZ_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Entity-type names as written in the annotation files -> tag suffixes used in the BIO labels.\n",
        "label_dict = {\n",
        "    '药物': 'DRUG',\n",
        "    '解剖部位': 'BODY',\n",
        "    '疾病和诊断': 'DISEASES',\n",
        "    '影像检查': 'EXAMINATIONS',\n",
        "    '实验室检验': 'TEST',\n",
        "    '手术': 'TREATMENT',\n",
        "}\n",
        "\n",
        "# Output files that receive the character-level train / validation / test sets.\n",
        "TRAIN = './CCKS_2019_Task1/processed_data/train_dataset.txt'\n",
        "VALID = './CCKS_2019_Task1/processed_data/val_dataset.txt'\n",
        "TEST = './CCKS_2019_Task1/processed_data/test_dataset.txt'\n",
        "\n",
        "def sentence2BIOlabel(sentence, label_from_file):\n",
        "    \"\"\"Build a character-level BIO tag sequence for one sentence.\n",
        "\n",
        "    Args:\n",
        "        sentence: The raw text; one tag is produced per character.\n",
        "        label_from_file: Annotation text with one entity per line. Each line\n",
        "            is tab-separated: field 1 is the start index, field 2 the end\n",
        "            index (treated as exclusive) and field 3 a key of ``label_dict``.\n",
        "            Field 0 is not read. Blank lines are ignored.\n",
        "\n",
        "    Returns:\n",
        "        A list of ``len(sentence)`` tags: ``'B-<type>'`` at each entity start,\n",
        "        ``'I-<type>'`` for the rest of the span and ``'O'`` everywhere else.\n",
        "\n",
        "    Raises:\n",
        "        KeyError: If an entity type is not a key of ``label_dict``.\n",
        "        IndexError: If a non-blank line has fewer than four fields or an\n",
        "            index points past the end of the sentence.\n",
        "        ValueError: If a start or end field is not an integer.\n",
        "    \"\"\"\n",
        "    sentence_label = ['O'] * len(sentence)\n",
        "\n",
        "    for line in label_from_file.split('\\n'):\n",
        "        line = line.strip()\n",
        "        # A blank line carries no entity. This covers an empty or\n",
        "        # whitespace-only annotation text as well as stray empty lines\n",
        "        # between records, which would otherwise fail on entity_info[1].\n",
        "        if not line:\n",
        "            continue\n",
        "\n",
        "        entity_info = line.split('\\t')\n",
        "        start_index = int(entity_info[1])\n",
        "        end_index = int(entity_info[2])\n",
        "        entity_label = label_dict[entity_info[3]]\n",
        "\n",
        "        # First character of the entity: B-xx\n",
        "        sentence_label[start_index] = 'B-' + entity_label\n",
        "        # Remaining characters of the entity: I-xx\n",
        "        for i in range(start_index + 1, end_index):\n",
        "            sentence_label[i] = 'I-' + entity_label\n",
        "    return sentence_label\n",
        "\n",
        "def loadRawData(fileName):\n",
        "    \"\"\"Load every annotated document in a directory and tag it with BIO labels.\n",
        "\n",
        "    A text file is any directory entry whose name contains ``'original'``; its\n",
        "    annotation file is the same name with ``'-original'`` removed. Both files\n",
        "    are read as UTF-8 and stripped of surrounding whitespace.\n",
        "\n",
        "    Args:\n",
        "        fileName: Path of the directory to scan (despite the name, not a\n",
        "            file). A trailing path separator is optional.\n",
        "\n",
        "    Returns:\n",
        "        ``(sentence_list, label_list)``: the document texts and, in the same\n",
        "        order, the tag lists produced by ``sentence2BIOlabel``. Documents are\n",
        "        returned in sorted file-name order.\n",
        "    \"\"\"\n",
        "    sentence_list = []\n",
        "    label_list = []\n",
        "\n",
        "    # os.listdir returns entries in arbitrary order; sorting keeps the output\n",
        "    # (and any split derived from it) the same from run to run.\n",
        "    for file_name in sorted(os.listdir(fileName)):\n",
        "        # Only the '*original*' text files drive the loop. Annotation files\n",
        "        # and stray entries such as '.DS_Store' do not match and are skipped.\n",
        "        if 'original' not in file_name:\n",
        "            continue\n",
        "\n",
        "        # os.path.join works whether or not fileName ends with a separator.\n",
        "        org_file = os.path.join(fileName, file_name)\n",
        "        lab_file = os.path.join(fileName, file_name.replace('-original', ''))\n",
        "\n",
        "        with open(org_file, encoding='utf-8') as f:\n",
        "            content = f.read().strip()\n",
        "\n",
        "        with open(lab_file, encoding='utf-8') as f:\n",
        "            content_label = f.read().strip()\n",
        "\n",
        "        sentence_list.append(content)\n",
        "        label_list.append(sentence2BIOlabel(content, content_label))\n",
        "\n",
        "    return sentence_list, label_list\n",
        "\n",
        "def Save_data(filename, texts, tags):\n",
        "    \"\"\"Write sentences and their tags to a file, one character per line.\n",
        "\n",
        "    Each output line holds one character, a tab, that character's tag and a\n",
        "    newline. Sentences are written back to back with no separator line.\n",
        "\n",
        "    Args:\n",
        "        filename: Path of the file to create or overwrite.\n",
        "        texts: Sequence of sentences, each an indexable sequence of characters.\n",
        "        tags: Sequence of tag sequences, parallel to ``texts``.\n",
        "\n",
        "    Raises:\n",
        "        IndexError: If a tag sequence is shorter than its sentence.\n",
        "    \"\"\"\n",
        "    # The inputs are read as UTF-8 and contain Chinese text, so write UTF-8\n",
        "    # explicitly instead of depending on the platform's default encoding.\n",
        "    with open(filename, 'w', encoding='utf-8') as f:\n",
        "        for sent, tag in zip(texts, tags):\n",
        "            for i, char in enumerate(sent):\n",
        "                f.write(char + '\\t' + tag[i] + '\\n')"
      ],
      "metadata": {
        "id": "1yw17zB0nmzb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Training data\n",
        "sentence_list, label_list = loadRawData('./CCKS_2019_Task1/data/')\n",
        "# Test data\n",
        "sentence_list_test, label_list_test = loadRawData('./CCKS_2019_Task1/data_test/')\n",
        "\n",
        "# Turn every document into a list of single characters, parallel to its tag list\n",
        "words = [list(sent) for sent in sentence_list]\n",
        "t_words = [list(sent) for sent in sentence_list_test]\n",
        "tags = label_list\n",
        "t_tags = label_list_test\n",
        "\n",
        "# Hold out 20% of the training documents for validation. The fixed seed gives the\n",
        "# same split on every run, provided the documents are loaded in the same order.\n",
        "train_texts, val_texts, train_tags, val_tags = train_test_split(\n",
        "    words, tags, test_size=.2, random_state=42)\n",
        "test_texts, test_tags = t_words, t_tags\n",
        "\n",
        "# Save_data does not create directories, so make sure the output folder exists\n",
        "for out_path in (TRAIN, VALID, TEST):\n",
        "    os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
        "\n",
        "# Obtain training, validating and testing files\n",
        "Save_data(TRAIN, train_texts, train_tags)\n",
        "Save_data(VALID, val_texts, val_tags)\n",
        "Save_data(TEST, test_texts, test_tags)"
      ],
      "metadata": {
        "id": "fGZUEmdBo0BK"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%pip install transformers==3.4"
      ],
      "metadata": {
        "id": "EvcXRCPxo3Ln"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%pip install pytorch-crf==0.7.2"
      ],
      "metadata": {
        "id": "kSi5VYRNo3Nx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Run main.py in a separate Python process with --n_epochs 30.\n",
        "# NOTE(review): presumably the training entry point that reads the dataset files written above - verify against main.py.\n",
        "!python main.py --n_epochs 30"
      ],
      "metadata": {
        "id": "FonoUHo1o3Pf"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}